package com.smart5G.utils;

import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@Intercepts({@Signature(type= StatementHandler.class, method = "prepare", args={Connection.class, Integer.class})})
public class PageInterceptor implements Interceptor {

    private String sqlRegEx = ".*Page";

    public Object intercept(Invocation invocation) throws Throwable {
        RoutingStatementHandler handler = (RoutingStatementHandler)invocation.getTarget();
        StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
        BoundSql boundSql = delegate.getBoundSql();
        MappedStatement mappedStatement = (MappedStatement)ReflectUtil.getFieldValue(delegate, "mappedStatement");
        // 获取参数
        Object parameterObject = boundSql.getParameterObject();
        // 判断是否分页
        try{
            if (PageUtil.isNeedPage()) {
                Connection connection = (Connection) invocation.getArgs()[0];
                // 获取mapper映射文件中对应的sql语句
                String sql = boundSql.getSql();
                // 给当前page参数设置总记录数
                this.setPageParameter(mappedStatement, connection, boundSql);
                // 获取分页sql语句
                String pageSql = this.getPageSql(sql);
                ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
            }
        }finally {
            PageUtil.noPage();
        }

        return invocation.proceed();
    }

    /**
     * 从数据库里查询总的记录数并计算总页数，回写进分页参数page
     * @param mappedStatement
     * @param connection
     * @param boundSql
     */
    private void setPageParameter(MappedStatement mappedStatement, Connection connection, BoundSql boundSql) {
        // 获取mapper映射文件中对应的sql语句
        String sql = boundSql.getSql();
        // 获取计算总记录数的sql语句
        String countSql = this.getCountSql(sql);
        // 获取BoundSql参数映射
        List<ParameterMapping> parameterMappinglist = boundSql.getParameterMappings();
        // 构造查询总量的BoundSql
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappinglist, boundSql.getParameterObject());
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)){
                countBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }

        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), countBoundSql);

        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            // 通过connection建立countSql对应的PreparedStatement对象
            pstmt = connection.prepareStatement(countSql);
            parameterHandler.setParameters(pstmt);
            // 执行countSql语句
            rs = pstmt.executeQuery();
            if (rs.next()) {
                int totalRecord = rs.getInt(1);
                PageUtil.setTotalRecord(totalRecord);
                PageUtil.setTotalPage(totalRecord/PageUtil.getPageSize() + (totalRecord % PageUtil.getPageSize() == 0? 0: 1));
            }
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    /**
     * 根据源sql语句获取对应的查询总记录数的sql语句
     * @param sql
     * @return
     */
    private String getCountSql(String sql) {
        int index = sql.indexOf("from");
        return "select count(*) " + sql.substring(index);
    }

    /**
     * 获取MySql数据库的分页查询语句
     * @param sql
     * @return
     */
    private String getPageSql(String sql) {
        StringBuffer sqlBuffer = new StringBuffer(sql);
        int offset = (PageUtil.getPageNum() - 1) * PageUtil.getPageSize();
        sqlBuffer.append(" limit ").append(offset).append(",").append(PageUtil.getPageSize());
        return sqlBuffer.toString();
    }

    /**
     * 只处理StatementHandler类型
     * @param o
     * @return
     */
    public Object plugin(Object o) {
        if (o instanceof  StatementHandler) {
            return Plugin.wrap(o, this);
        } else {
            return o;
        }
    }

    /**
     * 拦截器属性设定
     * @param properties
     */
    public void setProperties(Properties properties) {
    }

    public String getSqlRegEx() {
        return sqlRegEx;
    }

    public void setSqlRegEx(String sqlRegEx) {
        this.sqlRegEx = sqlRegEx;
    }
}
